/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file kernel_kfc.h
 * \brief
 */
#ifndef LIB_MATMUL_KERNEL_KFC_H
#define LIB_MATMUL_KERNEL_KFC_H

#if ASCENDC_CPU_DEBUG
#include <cstring>
#include <unistd.h>
#endif

#include "kernel_operator.h"
#include "lib/matmul/matmul_client.h"
#include "lib/matmul/matmul_server.h"
namespace AscendC {
class KfcServer { // AIC side
public:
    /**
     * Initialize the server side of the KFC (kernel function call) channel.
     * Records the GM workspace base, resets the quit counter and sets up one
     * message queue per paired AIV core.
     * @param workspaceGM base address of the GM workspace shared with the clients
     */
    __aicore__ inline void Init(GM_ADDR workspaceGM)
    {
        ASSERT(workspaceGM != nullptr && "workspaceGM cannot be nullptr when init kfc server");

        workspace = workspaceGM;
        quitSize = 0;
        for (int32_t i = 0; i < MIX_NUM; i++) {
            kfcCommSrv[i].Init(workspace, i); // Initialize the message queue on the server.
        }
    }

    /**
     * @return true while at least one AIV client is still active.
     */
    __aicore__ inline bool isRun()
    {
        // The function exits when all AIVs exit. The client sends a Quit message when the destructor ends.
        return quitSize < MIX_NUM;
    }

    /**
     * Server main loop body: polls every AIV message queue once and dispatches
     * each received message to the matmul service, the SCM copy service, or the
     * quit handler.
     * @param a first matmul service object
     * @param b remaining matmul service objects (and optional tiling pointers)
     */
    template <class T, class... Args> __aicore__ inline void Run(T& a, Args&&... b)
    {
        TRACE_START(TraceId::KFC_SERVER_RUN);
        auto ptr = kfcCommSrv;
        __gm__ KfcMsg* msg;
        bool ret = true;
        for (int i = 0; i < MIX_NUM;) { // Get messages of each AIV core in polling mode.
            TRACE_START(TraceId::KFC_SERVER_REV_MSG);
            msg = ptr->RcvMessage();
            TRACE_STOP(TraceId::KFC_SERVER_REV_MSG);
            if (msg) {
                // The check message is public
                TRACE_START(TraceId::KFC_SERVER_PROCESS_MSG);
                auto funID = KfcMsgGetFunID(msg->head);
                auto srvID = static_cast<KFC_Enum>(static_cast<uint16_t>(funID) &
                    static_cast<uint16_t>(KFC_Enum::SERVICE_ID_MASK));
                bool freeMsg = true;
                if (srvID == KFC_Enum::SERVICE_ID_MATMUL) {
                    ret = RunAux(i, msg, funID, freeMsg, a, b...);
                } else if (srvID == KFC_Enum::SERVICE_ID_SCM) {
                    if (funID == KFC_Enum::SCMFUN_GM2L1) {
                        ScmDataCopy(&msg->buffer);
                    } else if (funID == KFC_Enum::SCMFUN_GM2L1ND2NZ) {
                        ScmDataCopyND2NZ(&msg->buffer);
                    }
                    if (unlikely(msg->ubAddr >= 0)) {
                        ptr->FreeUB(msg->ubAddr);
                    }
                } else if (funID == KFC_Enum::SERVICE_QUIT) {
                    quitSize++;
                } else {
                    // Bugfix: asserting a string literal is always true and can never fire;
                    // assert on false with the message attached instead.
                    ASSERT(0 && "unsupported service id !");
                }
                if (freeMsg) {
                    ptr->FreeMessage(msg); // Move the message backward by one after the message processed.
                    TRACE_STOP(TraceId::KFC_SERVER_PROCESS_MSG);
                } else {
                    // The message stays queued for later re-processing; stop the trace
                    // here too so TRACE_START/TRACE_STOP remain balanced (bugfix).
                    TRACE_STOP(TraceId::KFC_SERVER_PROCESS_MSG);
                    ptr->RollBackMsg();
                    i++;
                    ptr++;
                    continue;
                }
            }
            if (ret) { // =false, lock a queue and must wait for release.
                i++;
                ptr++;
            }
        }
        TRACE_STOP(TraceId::KFC_SERVER_RUN);
    }

    /**
     * Bind every matmul service object in (a, b...) to the per-AIV message
     * queues. Each object is initialized once per sub-block (or once total when
     * it is a shared matmul).
     */
    template <class T, class... Args> __aicore__ inline void InitObj(TPipe* tpipe, T& a, Args&&... b)
    {
        if constexpr (sizeof(T) == sizeof(void*)) { // Skip previous invalid pointer for compatibility
            InitObj(b...);
        } else {
            ASSERT(kfcCommSrv != nullptr && "kfc comm server cannot be nullptr when init obj");
            auto ptr = kfcCommSrv;
            for (int i = 0; i < MIX_NUM; i++, ptr++) {
                InitObjAux(tpipe, ptr, i, 0, a, b...);
            }
        }
    }

    // No server-side teardown is required; clients signal exit via SERVICE_QUIT.
    __aicore__ inline void Quit()
    {}

    /**
     * Compile-time check: the first pack element is treated as a tiling pointer
     * when its size equals a pointer's size.
     */
    template <class T, class... Args> __aicore__ static inline constexpr bool isTiling()
    {
        return sizeof(T) == sizeof(void*);
    }

    // Extract the leading tiling pointer from the argument pack.
    template <class T, class... Args> __aicore__ static T* GetTiling(T* t, Args&&... b)
    {
        return t;
    }

private:
    // Drops the leading pack element (a consumed tiling pointer) and continues
    // dispatching on the rest.
    template <class T, class... Args>
    __aicore__ inline bool RunAuxSkip(int subBlockID, __gm__ KfcMsg* msg, KFC_Enum funID, bool& freeMsg,
        T& a, Args&&... b)
    {
        return RunAux(subBlockID, msg, funID, freeMsg, b...);
    }

    /**
     * Recursively walk the matmul service objects until one whose instance ID
     * matches the message header is found, then let it process the message.
     * Shared matmul objects use slot 0 regardless of sub-block; non-shared ones
     * use the per-sub-block slot.
     * @return false when the queue must be locked until release (per Run's contract).
     */
    template <class T, class... Args>
    __aicore__ inline bool RunAux(int subBlockID, __gm__ KfcMsg* msg, KFC_Enum funID, bool& freeMsg, T& a, Args&&... b)
    {
        ASSERT(msg != nullptr && "msg cannot be nullptr when kfc server run aux");
        ASSERT(subBlockID >= 0 && subBlockID < MIX_NUM && "sub block id should be [0, MIX_NUM)");
        if (a.mm.mm[0].IsSharedMatmul()) {
            if (a.mm.mm[0].GetInstID() == KfcMsgGetInstID(msg->head)) {
                if (a.mm.mm[0].ProcessIbShareSync(funID, freeMsg, lastMsgId, subBlockID)) {
                    return true;
                }
                freeMsg = true;
                a.mm.mm[0].SetSubBlockIdx(static_cast<uint8_t>(subBlockID));
                return a.mm.mm[0].Process(msg, funID);
            } else if constexpr (sizeof...(b) == 0) {
                ASSERT(0); // No object matched the message's instance ID.
                return true;
            } else if constexpr (isTiling<Args...>()) {
                if constexpr (sizeof...(b) > 1) {
                    return RunAuxSkip(subBlockID, msg, funID, freeMsg, b...);
                }
            } else if constexpr (sizeof...(b) >= 1) {
                return RunAux(subBlockID, msg, funID, freeMsg, b...);
            }
            return true;
        } else {
            if (a.mm.mm[subBlockID].GetInstID() == KfcMsgGetInstID(msg->head)) {
                if (a.mm.mm[subBlockID].ProcessIbShareSync(funID, freeMsg, lastMsgId, subBlockID)) {
                    return true;
                }
                freeMsg = true;
                a.mm.mm[subBlockID].SetSubBlockIdx(static_cast<uint8_t>(subBlockID));
                return a.mm.mm[subBlockID].Process(msg, funID);
            } else if constexpr (sizeof...(b) == 0) {
                ASSERT(0); // No object matched the message's instance ID.
                return true;
            } else if constexpr (isTiling<Args...>()) {
                if constexpr (sizeof...(b) > 1) {
                    return RunAuxSkip(subBlockID, msg, funID, freeMsg, b...);
                }
            } else if constexpr (sizeof...(b) >= 1) {
                return RunAux(subBlockID, msg, funID, freeMsg, b...);
            }
            return true;
        }
    }

    // Drops the leading pack element (a consumed tiling pointer) and continues
    // initializing the rest.
    template <class T, class... Args>
    __aicore__ inline void InitObjAuxSkip(TPipe* tpipe, KfcCommServer* kfc, int subBlockID, int instID, T* a,
        Args&&... b)
    {
        InitObjAux(tpipe, kfc, subBlockID, instID, b...);
    }

    /**
     * Recursively initialize each matmul service object with its optional
     * tiling pointer, assigning consecutive instance IDs. Shared matmul objects
     * are initialized only for sub-block 0.
     */
    template <class T, class... Args>
    __aicore__ inline void InitObjAux(TPipe *tpipe, KfcCommServer *kfc, int subBlockID, int instID, T &a, Args &&...b)
    {
        ASSERT(kfc != nullptr && "kfc cannot be nullptr when kfc server init obj aux");
        ASSERT(subBlockID >= 0 && subBlockID < MIX_NUM && "sub block id should be [0, MIX_NUM)");
        ASSERT(tpipe != nullptr);
        ASSERT(instID >= 0 && instID < MAX_MATMUL_OBJ && "matmul instID id be  [0, MAX_MATMUL_OBJ)");

        if constexpr (sizeof...(b) == 0) {
            if (a.mm.mm[0].IsSharedMatmul()) {
                if (subBlockID == 0) {
                    a.mm.mm[0].InitKfc(tpipe, (void *)nullptr, kfc, instID, workspace);
                }
            } else {
                a.mm.mm[subBlockID].InitKfc(tpipe, (void *)nullptr, kfc, instID, workspace);
            }
        } else if constexpr (isTiling<Args...>()) {
            auto tiling = GetTiling(b...);
            if (a.mm.mm[0].IsSharedMatmul()) {
                if (subBlockID == 0) {
                    a.mm.mm[0].InitKfc(tpipe, (void *)tiling, kfc, instID, workspace);
                    if constexpr (sizeof...(b) > 1) {
                        InitObjAuxSkip(tpipe, kfc, subBlockID, instID + 1, b...);
                    }
                } else {
                    if constexpr (sizeof...(b) > 1) {
                        InitObjAuxSkip(tpipe, kfc, subBlockID, instID + 1, b...);
                    }
                }
            } else {
                a.mm.mm[subBlockID].InitKfc(tpipe, (void *)tiling, kfc, instID, workspace);
                if constexpr (sizeof...(b) > 1) {
                    InitObjAuxSkip(tpipe, kfc, subBlockID, instID + 1, b...);
                }
            }
        } else {
            a.mm.mm[subBlockID].InitKfc(tpipe, (void *)nullptr, kfc, instID, workspace);
            if constexpr (sizeof...(b) >= 1) {
                InitObjAux(tpipe, kfc, subBlockID, instID + 1, b...);
            }
        }
    }

    // Apply for two servers on the server. aic<->aiv 1:1
    KfcCommServer kfcCommSrv[MIX_NUM];
    GM_ADDR workspace;   // GM workspace base shared with the AIV clients
    uint8_t quitSize;    // number of clients that have sent SERVICE_QUIT
    int lastMsgId = 1;   // last message id seen by ProcessIbShareSync
};

template <const MatmulConfig& MM_CFG = CFG_NORM>
constexpr bool IsSharedMatmul()
{
    // A matmul object is shared across sub-blocks exactly when per-instance
    // initialization is disabled in its configuration.
    return MM_CFG.enableInit == false;
}
// Empty base for the matmul instance holders below; carries the template
// parameters common to the shared and per-sub-block variants.
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE = C_TYPE,
    const MatmulConfig& MM_CFG = CFG_NORM, class MM_CB = matmul::MatmulCallBackFunc<nullptr, nullptr, nullptr>>
struct MatmulInstBase {
    __aicore__ inline MatmulInstBase(){};
};
// Holder for a shared matmul: a single MatmulService instance used by all
// sub-blocks (selected when the config disables per-instance init).
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const MatmulConfig& MM_CFG, class MM_CB>
struct MatmulInstShared : MatmulInstBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB> {
    __aicore__ inline MatmulInstShared(){};
    matmul::MatmulService<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB> mm[1];
};
// Holder for non-shared matmul: one MatmulService instance per AIV sub-block
// (MIX_NUM instances).
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const MatmulConfig& MM_CFG, class MM_CB>
struct MatmulInst : MatmulInstBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB> {
    __aicore__ inline MatmulInst(){};
    matmul::MatmulService<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB> mm[MIX_NUM];
};

// Compile-time selector mapping SHARED to the matching instance holder type;
// only the two bool specializations below are defined.
template <bool SHARED, class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const MatmulConfig& MM_CFG,
    class MM_CB>
struct MatmulInstAux {
    __aicore__ inline MatmulInstAux(){};
};

// SHARED == true: one service instance shared by all sub-blocks.
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const MatmulConfig& MM_CFG, class MM_CB>
struct MatmulInstAux<true, A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB> {
    __aicore__ inline MatmulInstAux(){};
    using MATMUL = MatmulInstShared<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB>;
};

// SHARED == false: one service instance per AIV sub-block.
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const MatmulConfig& MM_CFG, class MM_CB>
struct MatmulInstAux<false, A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB> {
    __aicore__ inline MatmulInstAux(){};
    using MATMUL = MatmulInst<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB>;
};

/**
 * AIC-side stand-in for the client-side MatmulImpl interface. It owns the real
 * MatmulService instance(s) in `mm` (shared or per-sub-block, chosen at compile
 * time from MM_CFG) and exposes the full MatmulImpl API as empty stubs so the
 * same operator code compiles on both cores; on the AIC the stubs do nothing.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE = C_TYPE, const MatmulConfig& MM_CFG = CFG_NORM,
    class MM_CB = matmul::MatmulCallBackFunc<nullptr, nullptr, nullptr>>
class MatmulServiceAux {
    using SrcT = typename A_TYPE::T;
    using SrcAT = typename A_TYPE::T;
    using SrcBT = typename B_TYPE::T;
    using DstT = typename C_TYPE::T;
    using BiasT = typename BIAS_TYPE::T;
    using handle = __gm__ MsgGroupSyncAux*;

public:
    __aicore__ inline MatmulServiceAux() {}
    // The actual matmul service object(s); shared vs per-sub-block is selected
    // by IsSharedMatmul<MM_CFG>().
    typename MatmulInstAux<IsSharedMatmul<MM_CFG>(), A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB>::MATMUL mm;

    // stub functions for MatmulImpl
    __aicore__ inline void Init(TCubeTiling* cubeTiling, TPipe* tpipe = nullptr){};

    __aicore__ inline void SetOrgShape(int orgM, int orgN, int orgK){};
    __aicore__ inline void SetOrgShape(int orgM, int orgN, int orgKa, int orgKb, int orgKc = 0){};
    __aicore__ inline void SetSingleShape(int singleM, int singleN, int singleK){};
    __aicore__ inline void SetTail(int tailM = -1, int tailN = -1, int tailK = -1){};

    __aicore__ inline void SetTensorA(const GlobalTensor<SrcAT>& gm, bool isTranspose = false){};

    __aicore__ inline void SetTensorAWithCopy(const GlobalTensor<SrcAT>& gm, const LocalTensor<SrcAT>& leftMatrix,
        bool isTranspose = false){};
    __aicore__ inline void SetTensorB(const GlobalTensor<SrcBT>& gm, bool isTranspose = false){};

    __aicore__ inline void SetTensorBWithCopy(const GlobalTensor<SrcBT>& gm, const LocalTensor<SrcBT>& righMatrix,
        bool isTranspose = false){};
    __aicore__ inline void SetBias(const GlobalTensor<BiasT>& biasGlobal){};
    __aicore__ inline void SetTensorA(const LocalTensor<SrcAT>& leftMatrix, bool isTranspose = false){};
    __aicore__ inline void SetTensorB(const LocalTensor<SrcBT>& righMatrix, bool isTranspose = false){};
    __aicore__ inline void SetBias(const LocalTensor<BiasT>& inputBias){};
    __aicore__ inline void SetTensorA(SrcAT aScalar){};
    __aicore__ inline void SetTensorB(SrcBT bScalar){};
    __aicore__ inline void ClearBias(){};
    __aicore__ inline void SetSelfDefineData(const uint64_t dataPtr) {}
    __aicore__ inline void SetUserDefInfo(const uint64_t tilingPtr) {}
    __aicore__ inline void SetQuantScalar(const uint64_t quantScalar) {}
    __aicore__ inline void SetQuantVector(const GlobalTensor<uint64_t>& quantTensor) {}
    template <class T> __aicore__ inline void SetWorkspace(__gm__ T* addr, int size) {};
    template <class T> __aicore__ inline void SetWorkspace(GlobalTensor<T>& addr){};
    __aicore__ inline void End(){};
    __aicore__ inline void SetHF32(bool enHF32 = false, int32_t transMode = 0){};

    // Stub always reports "no further iterations".
    template <bool sync = true> __aicore__ inline bool Iterate(bool enPartialSum = false)
    {
        return false;
    };
    template <bool sync = true>
    __aicore__ inline void IterateAll(const GlobalTensor<DstT>& gm, uint8_t enAtomic = 0,
        bool enSequentialWrite = false, bool waitIterateAll = false, bool fakeMsg = false){};
    template <bool sync = true>
    __aicore__ inline void IterateAll(const LocalTensor<DstT>& cMatrix, uint8_t enAtomic = 0){};
    __aicore__ inline void WaitIterateAll() {};
    template <bool sync = true, bool doPad = false>
    __aicore__ inline void GetTensorC(const LocalTensor<DstT>& c, uint8_t enAtomic = 0,
        bool enSequentialWrite = false, uint32_t height = 0, uint32_t width = 0, uint32_t srcGap = 0,
        uint32_t dstGap = 0) {};
    template <bool sync = true>
    __aicore__ inline void GetTensorC(const GlobalTensor<DstT>& gm, uint8_t enAtomic = 0,
        bool enSequentialWrite = false){};
    template <bool sync = true>
    __aicore__ inline void GetTensorC(const GlobalTensor<DstT> &c, const LocalTensor<DstT> &cLocal,
        uint8_t enAtomic = 0, bool enSequentialWrite = false) {};
    // Stub returns a default-constructed (empty) tensor.
    template <bool sync = true>
    __aicore__ inline GlobalTensor<DstT> GetTensorC(uint8_t enAtomic = 0, bool enSequentialWrite = false)
    {
        GlobalTensor<DstT> global;
        return global;
    };
    template <bool sync = true, bool waitIterateBatch = false>
    __aicore__ inline void IterateBatch(const GlobalTensor<DstT>& gm, uint32_t batchA, uint32_t batchB,
        bool enSequentialWrite, const uint32_t matrixStrideA = 0, const uint32_t matrixStrideB = 0,
        const uint32_t matrixStrideC = 0) {};
    template <bool sync = true>
    __aicore__ inline void IterateBatch(const LocalTensor<DstT>& ubCmatrix, uint32_t batchA, uint32_t batchB,
        bool enSequentialWrite, const uint32_t matrixStrideA = 0, const uint32_t matrixStrideB = 0,
        const uint32_t matrixStrideC = 0) {};
    template <bool sync = true, bool waitIterateBatch = false>
    __aicore__ inline void IterateNBatch(const uint32_t batchLoop, uint32_t batchA, uint32_t batchB,
        bool enSequentialWrite, const uint32_t matrixStrideA = 0, const uint32_t matrixStrideB = 0,
        const uint32_t matrixStrideC = 0) {};
    // Bugfix: this non-void stub previously had an empty body, so calling it
    // fell off the end of the function (undefined behavior). Return a
    // default-constructed tensor, consistent with the GetTensorC stub above.
    template <bool sync = true>
    __aicore__ inline GlobalTensor<DstT> GetBatchC(uint32_t batchA, uint32_t batchB, bool enSequentialWrite = false)
    {
        GlobalTensor<DstT> global;
        return global;
    };
    template <bool sync = true, bool doPad = false>
    __aicore__ inline void GetBatchC(const LocalTensor<DstT>& c, uint32_t batchA, uint32_t batchB,
        bool enSequentialWrite = false, uint32_t height = 0, uint32_t width = 0, uint32_t srcGap = 0,
        uint32_t dstGap = 0) {};
    __aicore__ inline void WaitIterateBatch() {};
    __aicore__ inline void SetLocalWorkspace(const LocalTensor<uint8_t>& tmpBuffer) {};
    __aicore__ inline void AsyncGetTensorC(const LocalTensor<DstT>& c){};
    __aicore__ inline void WaitGetTensorC(){};
    // Debug-only query; instantiating it with isTurnOnDebug == true is a
    // compile-time error by design (unsupported on the AIC side).
    template <bool isTurnOnDebug = true>
    __aicore__ inline MatrixOffset GetOffsetC()
    {
        if constexpr (isTurnOnDebug) {
            static_assert(!isTurnOnDebug, "unsupported!");
        }
    }
};

// Drops the leading pack element (a tiling pointer already consumed by the
// previous SetMatrixKfc step) and continues binding the remaining matmul
// objects with the next instance ID.
template <class T, class... Args>
__aicore__ inline void SetMatrixKfcSkip(TPipe* pipe, KfcCommClient* kfcClient, const int32_t instID, GM_ADDR workspace,
    T& mm, Args&&... b)
{
    SetMatrixKfc(pipe, kfcClient, instID, workspace, b...);
}

// Client-side counterpart of KfcServer::InitObj: recursively initializes each
// matmul object in (mm, b...) via InitKfcClient, assigning consecutive
// instance IDs. When the element following `mm` is a pointer-sized value it is
// treated as `mm`'s tiling pointer and consumed together with it (via
// SetMatrixKfcSkip); otherwise `mm` gets a null tiling.
template <class T, class... Args>
__aicore__ inline void SetMatrixKfc(TPipe* pipe, KfcCommClient* kfcClient, const int32_t instID, GM_ADDR workspace,
    T& mm, Args&&... b)
{
    ASSERT((pipe != nullptr) && "pipe should not be nullptr.");
    ASSERT((kfcClient != nullptr) && "kfcClient should not be nullptr.");
    ASSERT((workspace != nullptr) && "workspace should not be nullptr.");

    if constexpr (sizeof...(b) == 0) {
        // Last object, no tiling supplied.
        InitKfcClient(mm, (void*)nullptr, pipe, kfcClient, instID, workspace);
    } else if constexpr (KfcServer::isTiling<Args...>()) {
        // Next pack element is mm's tiling pointer; consume both.
        auto tiling = KfcServer::GetTiling(b...);
        InitKfcClient(mm, tiling, pipe, kfcClient, instID, workspace);
        if constexpr (sizeof...(b) > 1) {
            SetMatrixKfcSkip(pipe, kfcClient, instID + 1, workspace, b...);
        }
    } else {
        // No tiling for mm; continue with the rest of the pack.
        InitKfcClient(mm, (void*)nullptr, pipe, kfcClient, instID, workspace);
        if constexpr (sizeof...(b) >= 1) {
            SetMatrixKfc(pipe, kfcClient, instID + 1, workspace, b...);
        }
    }
}
}; // namespace AscendC

#endif
